"""
可视化 CentralFoveaAttention 注意力点坐标和权重的完整脚本
包含从 HDF5 读取的原始图像、注意力热力图和高斯中心坐标标记
"""
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import h5py
import webdataset as wds
from torch.utils.data import DataLoader
import sys
import random 


def find_latest_checkpoint(out_dir):
    """Locate the checkpoint with the highest step number in *out_dir*.

    Checkpoints are files named ``ckpt_<step>.pt``; the step is taken from the
    text between the first ``_`` and the following ``.``. Matching files whose
    step is not an integer (e.g. ``ckpt_best.pt``) are skipped rather than
    aborting the search.

    Args:
        out_dir: Directory to search.

    Returns:
        ``(path, step)`` for the newest checkpoint, or ``(None, 0)`` when the
        directory does not exist or contains no numbered checkpoint. On ties
        the first file in ``os.listdir`` order wins.
    """
    try:
        names = os.listdir(out_dir)
    except FileNotFoundError:
        # A missing output directory simply means "nothing trained yet".
        return None, 0

    candidates = []
    for name in names:
        if not (name.startswith("ckpt_") and name.endswith(".pt")):
            continue
        try:
            step = int(name.split("_")[1].split(".")[0])
        except ValueError:
            continue  # e.g. "ckpt_best.pt": not a numbered checkpoint
        candidates.append((step, name))

    if not candidates:
        return None, 0
    step, latest = max(candidates, key=lambda c: c[0])
    return os.path.join(out_dir, latest), step

# Make the project package importable. This must run before the `models`
# import below. NOTE(review): the path is relative, so it is resolved against
# the current working directory rather than this file's location — presumably
# the script is meant to be run from the directory containing mindeye2_src.
sys.path.append("mindeye2_src")

# Import the model components used for visualization.
from models import CentralFoveaAttention, VoxelVAE

# ==================== 1. Hyperparameters ====================
device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu"
)

# Dataset root and the NSD subject being visualized.
data_path = "dataset"
subj = 2

# Loader sizing. Within this script only `test_batch_size` feeds a loader;
# the epoch bookkeeping values appear to be carried over and are unused here.
batch_size = 1
test_batch_size = 3000
num_session = 40
num_samples_per_epoch = num_session * 750
num_iterations_per_epoch = num_samples_per_epoch // batch_size

# Checkpoint directory to read from, and where rendered figures are written.
output_dir = f"./subj0{subj}_vae_eye"
visualization_output_dir = f"./attention_visualizations_0{subj}"

# Ensure the figure output directory exists before anything is saved.
os.makedirs(visualization_output_dir, exist_ok=True)

# ==================== 2. Dataset preparation ====================
# Both HDF5 stores are opened read-only and kept open for the whole run;
# the main block closes them when it finishes.
clip_emb_file = h5py.File(f"{data_path}/clip_embeddings.hdf5", mode="r")
clip_embeddings = clip_emb_file["embeddings"]

image_file = h5py.File(f"{data_path}/coco_images_224_float16.hdf5", mode="r")
images = image_file["images"]  # dataset key assumed to be "images"

def my_split_by_node(urls):
    """Pass-through node splitter: this process keeps every shard URL.

    Used as ``nodesplitter`` for the test WebDataset so that shards are not
    divided between nodes.
    """
    return urls

test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"

# Single test shard: shuffled with a fixed seed (reproducible batch
# composition), decoded to torch tensors, and yielded as
# (behav, past_behav, future_behav, olds_behav) tuples.
test_data = (
    wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)
    .shuffle(750, initial=1500, rng=random.Random(42))
    .decode("torch")
    .rename(
        behav="behav.npy",
        past_behav="past_behav.npy",
        future_behav="future_behav.npy",
        olds_behav="olds_behav.npy",
    )
    .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
)

test_dl = DataLoader(
    test_data,
    batch_size=test_batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=4,
    prefetch_factor=2,
)

# ==================== 3. Model initialization ====================
# The voxel count is read from the betas HDF5 dataset; it presumably has to
# match the trained checkpoint, since load_state_dict below is strict by
# default. Only the shape metadata is read: materializing the whole array
# just to take its last dimension would load the full file into RAM.
with h5py.File(f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5', 'r') as f:
    num_voxels = f['betas'].shape[-1]

model = VoxelVAE(
    num_voxels=num_voxels,
    token_dim=768,
    num_tokens=257,
    hidden_dim=256,
    n_blocks=2,
    drop=0.15
).to(device)

attn = CentralFoveaAttention(embed_dim=768, grid_size=16).to(device)

# ==================== 4. Load model weights ====================
# Pick the newest ckpt_<step>.pt from output_dir and restore both the VAE and
# the attention module from it; abort early if nothing usable is there.
print(output_dir)
ckpt_path, _ = find_latest_checkpoint(output_dir)
print(ckpt_path)
if not ckpt_path or not os.path.isfile(ckpt_path):
    raise FileNotFoundError("No checkpoint found, please train the model first.")

print(f"🔄 Loading checkpoint from {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model"])
attn.load_state_dict(ckpt["attn"])

# ==================== 5. 可视化函数 ====================
def _to_display_image(image):
    """Return a float32 copy of *image* laid out for ``imshow``.

    Accepts a torch tensor or anything ``np.array`` understands. A
    channel-first ``(3, H, W)`` image is transposed to ``(H, W, 3)``. Values
    above 1 are treated as a 0-255 scale and divided by 255. The input is
    never modified because a fresh array is always built.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    arr = np.array(image, dtype=np.float32)  # always a copy
    if arr.ndim == 3 and arr.shape[0] == 3:
        arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    if arr.max() > 1.0:
        arr = arr / 255.0
    return arr


def visualize_attention_on_image(original_image,
                                 image_rep,
                                 attn_module,
                                 save_path=None,
                                 *,
                                 show_colorbar=False):
    """Overlay a CentralFoveaAttention weight map and its centre on an image.

    Args:
        original_image: Image as a torch tensor or NumPy array, either
            ``(3, H, W)`` or ``(H, W[, C])``. Values > 1 are rescaled from
            0-255. The caller's data is not modified.
        image_rep: Token embeddings passed (as float) to
            ``attn_module.get_mu_w``. They must already be on the module's
            device. Only batch element 0 is plotted.
        attn_module: Object with ``get_mu_w(x) -> (mu, w)``. ``mu[0]`` is
            mapped to pixels as if it were (x, y) in ``[-1, 1]`` (assumed —
            TODO confirm). ``w[0]`` holds one weight per token, with a
            leading CLS token that is dropped. The remaining weights must
            form a square grid (16x16 for 256 patches).
        save_path: When given, the figure is saved there (300 dpi) and
            closed. Otherwise it is displayed with ``plt.show()``.
        show_colorbar: Also draw a colorbar for the attention weights.

    Raises:
        ValueError: If the number of patch weights is not a perfect square.
    """
    with torch.no_grad():
        mu, w = attn_module.get_mu_w(image_rep.float())
    mu_np = mu[0].detach().cpu().numpy()        # (2,)
    patch_w = w[0, 1:].detach().cpu().numpy()   # drop the CLS token

    img_disp = _to_display_image(original_image)
    img_h, img_w = img_disp.shape[:2]

    grid = int(round(np.sqrt(patch_w.size)))
    if grid * grid != patch_w.size:
        raise ValueError(
            f"expected a square patch grid, got {patch_w.size} patch weights")
    patch_w_2d = patch_w.reshape(grid, grid)

    # Map the centre from [-1, 1] to pixel coordinates of the displayed image.
    focus_x = int((mu_np[0] + 1) / 2 * (img_w - 1))
    focus_y = int((mu_np[1] + 1) / 2 * (img_h - 1))

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    ax.imshow(img_disp)
    heat = ax.imshow(patch_w_2d, cmap='jet', alpha=0.5,
                     extent=[0, img_w, img_h, 0],  # stretch grid over the image
                     interpolation='bilinear')     # smooths the coarse grid
    # Mark the Gaussian centre.
    ax.plot(focus_x, focus_y, 'r+', markersize=20, markeredgewidth=3)
    ax.axis('off')

    if show_colorbar:
        cbar = fig.colorbar(heat, ax=ax, shrink=0.8)
        cbar.set_label('Attention Weight', rotation=270, labelpad=20)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")
        # Called once per sample in a loop; unclosed figures accumulate.
        plt.close(fig)
    else:
        plt.show()



# ==================== 6. 主执行逻辑 ====================
if __name__ == "__main__":
    model.eval()
    attn.eval()

    # Upper bound on the number of distinct test images rendered.
    max_visualizations = 1000

    try:
        with torch.no_grad():
            # Collect each distinct test image and its CLIP embedding once.
            # torch.unique sorts within a batch; later batches only add
            # images not seen before.
            clip_emb_list = []
            image_list = []
            seen = set()
            for behav, _, _, _ in test_dl:
                # behav[:, 0, 0] is used as the row index into both HDF5 stores.
                image_idx = behav[:, 0, 0].cpu().long()
                for im in torch.unique(image_idx).tolist():
                    if im in seen:
                        continue
                    seen.add(im)
                    clip_emb_list.append(torch.tensor(clip_embeddings[im]))
                    image_list.append(torch.tensor(images[im]))

            # Render after all batches are gathered, without indexing past
            # the number of images actually available.
            num_to_plot = min(max_visualizations, len(image_list))
            if num_to_plot == 0:
                print("No test images found; nothing to visualize.")
            for sample_idx in range(num_to_plot):
                save_path = os.path.join(
                    visualization_output_dir,
                    f"attention_visualization_{sample_idx:04d}.png")
                original_image = image_list[sample_idx].float()
                embedding0 = clip_emb_list[sample_idx].unsqueeze(0).float().to(device)
                visualize_attention_on_image(original_image, embedding0, attn,
                                             save_path=save_path)
    finally:
        # Close the HDF5 stores even if loading or plotting fails.
        clip_emb_file.close()
        image_file.close()